{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conversion to ONNX from a Keras model using the tf2onnx Python API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the sample image used throughout this notebook.\n",
    "# -nc (no-clobber) skips the download when ade20k.jpg already exists, so\n",
    "# re-running this cell does not leave an ade20k.jpg.1 copy behind.\n",
    "!wget -q -nc https://raw.githubusercontent.com/onnx/tensorflow-onnx/main/tests/ade20k.jpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of\n",
    "# the kernel running this notebook rather than whichever pip is first on PATH.\n",
    "%pip install tensorflow tf2onnx onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.applications.resnet50 import (\n",
    "    ResNet50,\n",
    "    decode_predictions,\n",
    "    preprocess_input,\n",
    ")\n",
    "from tensorflow.keras.preprocessing import image\n",
    "import onnxruntime\n",
    "\n",
    "# Load the sample picture resized to 224x224 and turn it into a\n",
    "# preprocessed batch holding that single image.\n",
    "img_path = 'ade20k.jpg'\n",
    "img = image.load_img(img_path, target_size=(224, 224))\n",
    "x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the Keras model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keras Predicted: [('n04285008', 'sports_car', 0.34311807), ('n02974003', 'car_wheel', 0.28819188), ('n03100240', 'convertible', 0.10169428)]\n",
      "INFO:tensorflow:Assets written to: /tmp/resnet50/assets\n"
     ]
    }
   ],
   "source": [
    "# Classify the sample image with ResNet50 using pretrained ImageNet weights.\n",
    "model = ResNet50(weights='imagenet')\n",
    "preds = model.predict(x)\n",
    "\n",
    "top3 = decode_predictions(preds, top=3)[0]\n",
    "print('Keras Predicted:', top3)\n",
    "\n",
    "# Save the model under /tmp/<model name>; the command-line conversion at the\n",
    "# end of the notebook reads it back from this path.\n",
    "saved_model_dir = os.path.join(\"/tmp\", model.name)\n",
    "model.save(saved_model_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert to ONNX using the Python API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ipython/.local/lib/python3.7/site-packages/tf2onnx/tf_loader.py:558: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n"
     ]
    }
   ],
   "source": [
    "import tf2onnx\n",
    "import onnxruntime as rt\n",
    "\n",
    "# Input signature: float32 batches of 224x224x3 images with a dynamic batch size.\n",
    "# The tensor name \"input\" is the key used later to feed the ONNX session.\n",
    "output_path = model.name + \".onnx\"\n",
    "spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name=\"input\"),)\n",
    "\n",
    "model_proto, _ = tf2onnx.convert.from_keras(\n",
    "    model,\n",
    "    input_signature=spec,\n",
    "    opset=13,\n",
    "    output_path=output_path,\n",
    ")\n",
    "output_names = [out.name for out in model_proto.graph.output]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the ONNX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX Predicted: [('n04285008', 'sports_car', 0.34311718), ('n02974003', 'car_wheel', 0.2881928), ('n03100240', 'convertible', 0.10169421)]\n"
     ]
    }
   ],
   "source": [
    "# Run the converted model with ONNX Runtime on the CPU. The feed key \"input\"\n",
    "# matches the name given to the TensorSpec during conversion.\n",
    "providers = ['CPUExecutionProvider']\n",
    "m = rt.InferenceSession(output_path, providers=providers)\n",
    "onnx_pred = m.run(output_names, {\"input\": x})\n",
    "onnx_out = onnx_pred[0]\n",
    "\n",
    "print('ONNX Predicted:', decode_predictions(onnx_out, top=3)[0])\n",
    "\n",
    "# The ONNX output must agree with the Keras predictions.\n",
    "np.testing.assert_allclose(preds, onnx_out, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert to ONNX using the command line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.7/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour\n",
      "  warn(RuntimeWarning(msg))\n",
      "2021-02-22 16:19:56,658 - WARNING - '--tag' not specified for saved_model. Using --tag serve\n",
      "2021-02-22 16:20:01,809 - INFO - Signatures found in model: [serving_default].\n",
      "2021-02-22 16:20:01,809 - WARNING - '--signature_def' not specified, using first signature: serving_default\n",
      "2021-02-22 16:20:01,810 - INFO - Output names: ['predictions']\n",
      "WARNING:tensorflow:From /home/ipython/.local/lib/python3.7/site-packages/tf2onnx/tf_loader.py:558: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "2021-02-22 16:20:03,659 - WARNING - From /home/ipython/.local/lib/python3.7/site-packages/tf2onnx/tf_loader.py:558: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "2021-02-22 16:20:04,540 - INFO - Using tensorflow=2.4.1, onnx=1.8.1, tf2onnx=1.9.0/ed89ef\n",
      "2021-02-22 16:20:04,540 - INFO - Using opset <onnx, 13>\n",
      "2021-02-22 16:20:07,401 - INFO - Computed 0 values for constant folding\n",
      "2021-02-22 16:20:13,830 - INFO - Optimizing ONNX model\n",
      "2021-02-22 16:20:15,152 - INFO - After optimization: Add -1 (18->17), BatchNormalization -53 (53->0), Cast -1 (1->0), Concat +1 (0->1), Const -162 (271->109), Identity -57 (57->0), Reshape -1 (1->0), Split +1 (0->1), Transpose -213 (215->2)\n",
      "2021-02-22 16:20:15,259 - INFO - \n",
      "2021-02-22 16:20:15,259 - INFO - Successfully converted TensorFlow model /tmp/resnet50 to ONNX\n",
      "2021-02-22 16:20:15,259 - INFO - Model inputs: ['input_1:0']\n",
      "2021-02-22 16:20:15,259 - INFO - Model outputs: ['predictions']\n",
      "2021-02-22 16:20:15,259 - INFO - ONNX model is saved at /tmp/resnet50.onnx\n"
     ]
    }
   ],
   "source": [
    "# Convert the model saved under /tmp with the tf2onnx command-line tool.\n",
    "!python -m tf2onnx.convert \\\n",
    "    --saved-model {os.path.join(\"/tmp\", model.name)} \\\n",
    "    --output {os.path.join(\"/tmp\", model.name + \".onnx\")} \\\n",
    "    --opset 13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
